package com.iflytek.jzcpx.procuracy.web.job;

import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Lists;
import com.google.common.collect.Table;
import com.iflytek.jzcpx.procuracy.card.entity.CardFile;
import com.iflytek.jzcpx.procuracy.card.entity.CardFormField;
import com.iflytek.jzcpx.procuracy.card.service.CardFileService;
import com.iflytek.jzcpx.procuracy.card.service.CardFormFieldService;
import com.iflytek.jzcpx.procuracy.common.result.WebResult;
import com.iflytek.jzcpx.procuracy.ocr.entity.Metric;
import com.iflytek.jzcpx.procuracy.ocr.entity.OcrFile;
import com.iflytek.jzcpx.procuracy.ocr.entity.OcrTask;
import com.iflytek.jzcpx.procuracy.ocr.entity.RecognizeFile;
import com.iflytek.jzcpx.procuracy.ocr.entity.RecognizeTask;
import com.iflytek.jzcpx.procuracy.ocr.service.MetricService;
import com.iflytek.jzcpx.procuracy.ocr.service.OcrFileService;
import com.iflytek.jzcpx.procuracy.ocr.service.OcrTaskService;
import com.iflytek.jzcpx.procuracy.ocr.service.RecognizeFileService;
import com.iflytek.jzcpx.procuracy.ocr.service.RecognizeTaskService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.FastDateFormat;
import org.apache.commons.math3.util.Pair;
import org.apache.shardingsphere.driver.jdbc.core.datasource.ShardingSphereDataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

/**
 * Maintenance endpoints: purge table data and/or the associated fastDFS files for a time range,
 * and copy rows of a supported table from the physical {@code ds0} data source back in through
 * the entity services in id-ordered chunks.
 *
 * @author <a href=mailto:ktyi@iflytek.com>伊开堂</a>
 * @date 2022/6/14
 */
@RestController
@RequestMapping("/clean")
@Api(tags = "清理数据")
public class CleanDataController {
    private static final Logger logger = LoggerFactory.getLogger(CleanDataController.class);

    /** Pattern expected for the {@code startTime} / {@code endTime} request parameters. */
    private static final String TIME_PATTERN = "yyyy-MM-dd HH:mm:ss";

    /** Number of consecutive ids covered by one migration chunk. */
    private static final int STEP = 1000;

    @Autowired
    private CleanDataJob cleanDataJob;
    @Autowired
    private DataSource dataSource;
    @Autowired
    private ApplicationContext context;
    @Autowired
    private OcrTaskService ocrTaskService;
    @Autowired
    private OcrFileService ocrFileService;
    @Autowired
    private RecognizeTaskService recognizeTaskService;
    @Autowired
    private RecognizeFileService recognizeFileService;
    @Autowired
    private CardFileService cardFileService;
    @Autowired
    private CardFormFieldService cardFormFieldService;
    @Autowired
    private MetricService metricService;

    /** table name -> column name -> (entity field, reader extracting that column from a ResultSet). */
    private Table<String, String, Pair<Field, BiFunction<ResultSet, String, Object>>> tableColumns;
    /** table name -> entity class; also acts as the whitelist of tables that may be migrated. */
    private Map<String, Class<?>> tableNameClass;

    /**
     * Cleans a table for the given time range by delegating to {@link CleanDataJob} with both the
     * database flag and the fastDFS flag enabled.
     *
     * @param startTime lower bound in {@code yyyy-MM-dd HH:mm:ss}; blank means no bound is passed
     * @param endTime   upper bound in {@code yyyy-MM-dd HH:mm:ss}; blank means no bound is passed
     * @param tableName table to clean
     * @return success with the cleaned count, or failure when a time cannot be parsed
     */
    @PostMapping("/data")
    @ApiOperation(value = "清理表数据[危险!!! 慎重操作!!!]", notes = "物理删除数据库表数据和关联的fastDFS文件")
    public WebResult<?> cleanData(@RequestParam String startTime, @RequestParam String endTime, @RequestParam String tableName) {
        logger.info("清理表数据, startTime: {}, endTime: {}, tableName: {}", startTime, endTime, tableName);
        int cleanCount;
        try {
            cleanCount = clean(tableName, startTime, endTime, true, true);
        }
        catch (ParseException e) {
            return WebResult.failed("时间格式错误, 正确格式为: yyyy-MM-dd HH:mm:ss");
        }
        logger.info("清理表{}数据结束, 清理条数: {}", tableName, cleanCount);
        return WebResult.success("已清理" + tableName + "表中的" + cleanCount + "条数据");
    }

    /**
     * Parses the optional time bounds and forwards the request to {@link CleanDataJob}.
     *
     * @param tableName table to clean
     * @param startTime lower bound text; blank is passed on as {@code null}
     * @param endTime   upper bound text; blank is passed on as {@code null}
     * @param cleanDB   forwarded as-is to the job
     * @param cleanFdfs forwarded as-is to the job
     * @return the count reported by the job
     * @throws ParseException when a non-blank bound does not match {@code yyyy-MM-dd HH:mm:ss}
     */
    private int clean(String tableName, String startTime, String endTime, boolean cleanDB, boolean cleanFdfs)
            throws ParseException {
        FastDateFormat format = FastDateFormat.getInstance(TIME_PATTERN);
        Date startDate = StringUtils.isNotBlank(startTime) ? format.parse(startTime) : null;
        Date endDate = StringUtils.isNotBlank(endTime) ? format.parse(endTime) : null;

        return cleanDataJob.cleanTableData(tableName, startDate, endDate, cleanDB, cleanFdfs);
    }

    /**
     * Same as {@link #cleanData} but calls the job with the database flag disabled, so only the
     * fastDFS side is requested.
     *
     * @param startTime lower bound in {@code yyyy-MM-dd HH:mm:ss}; blank means no bound is passed
     * @param endTime   upper bound in {@code yyyy-MM-dd HH:mm:ss}; blank means no bound is passed
     * @param tableName table whose files are cleaned
     * @return success with the cleaned count, or failure when a time cannot be parsed
     */
    @PostMapping("/fdfs")
    @ApiOperation(value = "清理fastDFS文件[危险!!! 慎重操作!!!]", notes = "只删除fastDFS文件, 不删除数据库数据")
    public WebResult<?> cleanFdfs(@RequestParam String startTime, @RequestParam String endTime, @RequestParam String tableName) {
        logger.info("清理fastDFS文件, startTime: {}, endTime: {}, tableName: {}", startTime, endTime, tableName);
        int cleanCount;
        try {
            cleanCount = clean(tableName, startTime, endTime, false, true);
        }
        catch (ParseException e) {
            return WebResult.failed("时间格式错误, 正确格式为: yyyy-MM-dd HH:mm:ss");
        }
        logger.info("清理fastDFS文件{}数据结束, 清理条数: {}", tableName, cleanCount);
        return WebResult.success("已清理" + tableName + "表中的" + cleanCount + "条fastDFS文件");
    }

    /**
     * Builds the table-name whitelist and, by reflecting over the {@code @TableId} /
     * {@code @TableField} annotations of each supported entity, the per-column readers used by
     * {@link #transData}.
     */
    @PostConstruct
    public void init() {
        tableNameClass = new HashMap<>();
        tableColumns = HashBasedTable.create();

        List<Class<?>> entityClasses = Lists.newArrayList(OcrFile.class, OcrTask.class,
                                                          RecognizeTask.class, RecognizeFile.class,
                                                          Metric.class, CardFormField.class,
                                                          CardFile.class);
        for (Class<?> entityClass : entityClasses) {
            String tableName = entityClass.getAnnotation(TableName.class).value();
            tableNameClass.put(tableName, entityClass);

            for (Field field : entityClass.getDeclaredFields()) {
                registerColumn(tableName, field);
            }
        }
    }

    /** Registers the reader for one annotated entity field; fields without either annotation are ignored. */
    private void registerColumn(String tableName, Field field) {
        TableId tableId = field.getAnnotation(TableId.class);
        if (tableId != null) {
            BiFunction<ResultSet, String, Object> idReader = CleanDataController::readId;
            tableColumns.put(tableName, tableId.value(), Pair.create(field, idReader));
            return;
        }

        TableField tableField = field.getAnnotation(TableField.class);
        if (tableField != null) {
            String typeName = field.getType().getName();
            BiFunction<ResultSet, String, Object> columnReader = (rs, label) -> readColumn(rs, label, typeName);
            tableColumns.put(tableName, tableField.value(), Pair.create(field, columnReader));
        }
    }

    /** Reads the primary-key column as a long; returns {@code null} (and logs) when the read fails. */
    private static Long readId(ResultSet rs, String labelName) {
        try {
            return rs.getLong(labelName);
        }
        catch (Exception e) {
            logger.warn("读取数据库字段 {} 异常", labelName, e);
        }
        return null;
    }

    /**
     * Reads one column using the accessor that matches the entity field's declared type.
     *
     * @param rs        result set positioned on the current row
     * @param labelName column label to read
     * @param typeName  fully-qualified name of the entity field's type
     * @return the column value, or {@code null} for SQL NULL or when the read fails
     */
    private static Object readColumn(ResultSet rs, String labelName, String typeName) {
        try {
            Object data;
            switch (typeName) {
                case "java.util.Date":
                    Timestamp timestamp = rs.getTimestamp(labelName);
                    return rs.wasNull() || timestamp == null ? null : new Date(timestamp.getTime());
                case "java.lang.String":
                    data = rs.getString(labelName);
                    break;
                case "java.lang.Long":
                    data = rs.getLong(labelName);
                    break;
                case "java.lang.Integer":
                    data = rs.getInt(labelName);
                    break;
                default:
                    // Other field types used to be silently migrated as null. Let the driver pick
                    // the Java type; an incompatible value is reported when the field is assigned.
                    data = rs.getObject(labelName);
                    break;
            }

            // getLong/getInt return 0 for SQL NULL, so wasNull() is what distinguishes the two.
            return rs.wasNull() ? null : data;
        }
        catch (Exception e) {
            logger.warn("读取数据库字段 {} 异常", labelName, e);
        }
        return null;
    }

    /**
     * Copies the rows of {@code tableName} whose id lies in {@code [beginId, endId]}: rows are read
     * from the physical {@code ds0} data source in chunks aligned to multiples of {@value #STEP}
     * ids and handed to the matching entity service via {@code saveBatch}.
     *
     * <p>A failure while reading or saving one chunk is logged and the remaining chunks are still
     * processed; the returned message carries the counted, queried and migrated totals.
     *
     * @param tableName table to migrate; must be one of the tables registered in {@link #init()}
     * @param beginId   first id to migrate (inclusive, at least 1)
     * @param endId     last id to migrate (inclusive, not smaller than {@code beginId})
     * @return failure for invalid arguments, an unavailable data source, an entity that cannot be
     *         instantiated or an empty range; otherwise success with the statistics
     * @throws SQLException declared for interface compatibility
     */
    @GetMapping("/sharding/trans/{tableName}")
    @ApiOperation(value = "分表数据迁移", notes = "单表迁移至分表")
    public WebResult<?> transData(@PathVariable String tableName,
            @RequestParam(required = false, defaultValue = "1") int beginId,
            @RequestParam(required = false, defaultValue = "10000") int endId) throws SQLException {
        // The whitelist lookup is also what makes concatenating tableName into SQL below safe.
        Class<?> entityClass = tableNameClass.get(tableName);
        if (entityClass == null) {
            return WebResult.failed("该表不支持迁移");
        }
        if (beginId < 1 || endId < 1 || beginId > endId) {
            return WebResult.failed("beginId/endId错误");
        }

        // Resolve the no-arg constructor explicitly: getConstructors()[0] has no defined order and
        // may be a constructor that requires arguments.
        Constructor<?> constructor;
        try {
            constructor = entityClass.getDeclaredConstructor();
            constructor.setAccessible(true);
        }
        catch (Exception e) {
            logger.warn(entityClass.getName() + " 缺少可用的无参构造方法", e);
            return WebResult.failed(entityClass.getName() + "无法实例化");
        }

        DataSource ds = resolveSourceDataSource();
        if (ds == null) {
            return WebResult.failed("数据源获取失败");
        }

        // Round r covers ids (r - 1) * STEP + 1 .. r * STEP; computed in long to avoid overflow.
        long firstRound = ceilDiv(beginId, STEP);
        long lastRound = ceilDiv(endId, STEP);

        long count = 0, queryTotal = 0, transTotal = 0;
        try (Connection connection = ds.getConnection()) {
            count = countRows(connection, tableName, beginId, endId);
            logger.info("===== 待迁移数据{}条 ====", count);
            if (count == 0) {
                logger.warn("没有可用数据, 迁移结束.");
                return WebResult.failed("没有可用数据");
            }

            String sql = "select * from " + tableName + " where id between ? and ?";
            try (PreparedStatement statement = connection.prepareStatement(sql)) {
                for (long round = firstRound; round <= lastRound; round++) {
                    long subStart = Math.max((round - 1) * STEP + 1, beginId);
                    long subEnd = Math.min(round * STEP, endId);

                    statement.setLong(1, subStart);
                    statement.setLong(2, subEnd);
                    logger.info(">>>> 分段查询{}, id范围: {} - {}", tableName, subStart, subEnd);

                    List<Object> datas = new ArrayList<>();
                    int subTotal = 0;
                    try (ResultSet resultSet = statement.executeQuery()) {
                        while (resultSet.next()) {
                            queryTotal++;
                            subTotal++;

                            Object entity;
                            try {
                                entity = constructor.newInstance();
                            }
                            catch (Exception e) {
                                logger.warn(entityClass.getName() + " 实例化异常", e);
                                return WebResult.failed(entityClass.getName() + "无法实例化");
                            }

                            populate(entity, resultSet, tableName);
                            datas.add(entity);
                        }
                    }
                    catch (Exception e) {
                        // Rows read before the failure are complete and are still saved below.
                        logger.warn("分段查询异常", e);
                    }
                    logger.info(">>>> 分段区间 {} - {} 查询结果数: {}", subStart, subEnd, subTotal);
                    if (datas.isEmpty()) {
                        logger.info("该分段区间数据为空");
                        continue;
                    }

                    logger.info("++++++  向分表中插入分段数据, 条数: {}", datas.size());
                    boolean saved = saveChunk(tableName, datas);

                    logger.info("分表插入数据{}, table: {},  id: {} - {}, size: {}", saved ? "成功" : "失败", tableName,
                                subStart, subEnd, subTotal);
                    if (saved) {
                        transTotal += datas.size();
                    }
                }
            }
        }
        catch (Exception e) {
            logger.warn("迁移异常", e);
        }

        String log = "待迁移数据总数: " + count + ", 分段查询总结果数: " + queryTotal + ", 已迁移数量: " + transTotal;
        logger.info(log);
        return WebResult.success(log);
    }

    /**
     * Ceiling of {@code value / divisor} for positive operands, in exact integer arithmetic.
     * A float-based ceiling loses precision once the id exceeds 2^24 and could drop the last round.
     */
    private static long ceilDiv(long value, long divisor) {
        return (value + divisor - 1) / divisor;
    }

    /** Looks up the physical {@code ds0} data source of {@code logic_db}; returns {@code null} on any failure. */
    private DataSource resolveSourceDataSource() {
        try {
            return ((ShardingSphereDataSource) dataSource).getContextManager().getMetaDataContexts()
                                                          .getMetaDataMap().get("logic_db").getResource()
                                                          .getDataSources().get("ds0");
        }
        catch (Exception e) {
            logger.warn("数据源获取异常", e);
            return null;
        }
    }

    /** Counts the rows of the (already whitelisted) table whose id lies in the inclusive range. */
    private long countRows(Connection connection, String tableName, long beginId, long endId) throws SQLException {
        String sql = "select count(1) from " + tableName + " where id between ? and ?";
        try (PreparedStatement countStatement = connection.prepareStatement(sql)) {
            countStatement.setLong(1, beginId);
            countStatement.setLong(2, endId);
            try (ResultSet countRS = countStatement.executeQuery()) {
                return countRS.next() ? countRS.getLong(1) : 0;
            }
        }
    }

    /**
     * Copies every known column of the current row into the matching field of {@code entity}.
     * Columns without a registered field are logged and skipped; a failed assignment is logged
     * and leaves that field untouched.
     */
    private void populate(Object entity, ResultSet resultSet, String tableName) throws SQLException {
        int columnCount = resultSet.getMetaData().getColumnCount();
        for (int j = 1; j <= columnCount; j++) {
            String columnName = resultSet.getMetaData().getColumnName(j);
            String columnLabel = resultSet.getMetaData().getColumnLabel(j);

            Pair<Field, BiFunction<ResultSet, String, Object>> binding =
                    tableColumns.get(tableName, columnName.toLowerCase());
            if (binding == null) {
                logger.warn("未知的数据库字段: {}", columnName);
                continue;
            }

            Field field = binding.getFirst();
            Object data = binding.getSecond().apply(resultSet, columnLabel);
            try {
                field.setAccessible(true);
                field.set(entity, data);
            }
            catch (Exception e) {
                logger.warn(entity.getClass().getName() + "." + field.getName() + " 赋值异常", e);
            }
        }
    }

    /**
     * Saves one chunk through the service that belongs to {@code tableName}, logging the ids first.
     *
     * @return the service's {@code saveBatch} result; {@code false} when saving throws
     */
    private boolean saveChunk(String tableName, List<Object> datas) {
        try {
            switch (tableName) {
                case "t_ocr_file":
                    List<OcrFile> ocrFiles = castAll(datas, OcrFile.class);
                    logger.info("[OcrFile] 迁移的原始数据id: {}", ocrFiles.stream().map(OcrFile::getId)
                                                                   .map(String::valueOf)
                                                                   .collect(Collectors.joining(", ")));
                    return ocrFileService.saveBatch(ocrFiles);
                case "t_ocr_task":
                    List<OcrTask> ocrTasks = castAll(datas, OcrTask.class);
                    logger.info("[OcrTask] 迁移的原始数据id: {}", ocrTasks.stream().map(OcrTask::getId)
                                                                   .map(String::valueOf)
                                                                   .collect(Collectors.joining(", ")));
                    return ocrTaskService.saveBatch(ocrTasks);
                case "t_recognize_file":
                    List<RecognizeFile> recognizeFiles = castAll(datas, RecognizeFile.class);
                    logger.info("[RecognizeFile] 迁移的原始数据id: {}", recognizeFiles.stream().map(
                            RecognizeFile::getId).map(String::valueOf).collect(Collectors.joining(", ")));
                    return recognizeFileService.saveBatch(recognizeFiles);
                case "t_recognize_task":
                    List<RecognizeTask> recognizeTasks = castAll(datas, RecognizeTask.class);
                    logger.info("[RecognizeTask] 迁移的原始数据id: {}", recognizeTasks.stream().map(
                            RecognizeTask::getId).map(String::valueOf).collect(Collectors.joining(", ")));
                    return recognizeTaskService.saveBatch(recognizeTasks);
                case "t_metric":
                    List<Metric> metrics = castAll(datas, Metric.class);
                    logger.info("[Metric] 迁移的原始数据id: {}", metrics.stream().map(Metric::getId)
                                                                 .map(String::valueOf)
                                                                 .collect(Collectors.joining(", ")));
                    return metricService.saveBatch(metrics);
                case "t_card_file":
                    List<CardFile> cardFiles = castAll(datas, CardFile.class);
                    logger.info("[CardFile] 迁移的原始数据id: {}", cardFiles.stream().map(CardFile::getId)
                                                                     .map(String::valueOf)
                                                                     .collect(Collectors.joining(", ")));
                    return cardFileService.saveBatch(cardFiles);
                case "t_card_form_field":
                    List<CardFormField> cardFormFields = castAll(datas, CardFormField.class);
                    logger.info("[CardFormField] 迁移的原始数据id: {}", cardFormFields.stream().map(
                            CardFormField::getId).map(String::valueOf).collect(Collectors.joining(", ")));
                    return cardFormFieldService.saveBatch(cardFormFields);
                default:
                    // No service is wired for this table name, so nothing was saved.
                    return false;
            }
        }
        catch (Exception e) {
            logger.warn("分表数据插入异常", e);
            return false;
        }
    }

    /** Casts every element of {@code rows} to {@code type}, preserving order. */
    private static <T> List<T> castAll(List<Object> rows, Class<T> type) {
        return rows.stream().map(type::cast).collect(Collectors.toList());
    }

}